rm(list=ls())
library(FactorHet)
library(tidyverse)
library(cjbart)

# Seed the RNG before any model fitting.
set.seed(123)

# Load the packaged data set.
repdata_nonhisp <- readRDS("code/packaged_data.RDS")

### Build cjbart model
# Names of the conjoint attribute columns and of the moderator columns.
attributes <- c(
  "Ed", "Gender", "Country", "Reason", "Job",
  "Exp", "Plans", "Trips", "Lang"
)
moderators <- c(
  "scale_hisp_prej_flip", "party_ID", "ppEducat", "census_div", "ppEthm"
)

# Keep only the outcome, the respondent and contest identifiers, and the
# attribute and moderator columns.
data_use <- repdata_nonhisp[, c(
  "Chosen_Immigrant", "CaseID", "contest_no", attributes, moderators
)]

# Store every attribute column as character.
data_use[, attributes] <- lapply(data_use[, attributes], as.character)

# Fit the cjbart model with Chosen_Immigrant as the outcome, CaseID as the
# respondent identifier and contest_no as the round identifier.
cjmodel <- cjbart(
  data = data_use,
  Y = "Chosen_Immigrant",
  id = "CaseID",
  round = "contest_no",
  ntree = 15,
  numcut = 20
)

# Get heterogeneous IMCEs
# Baseline levels; this appears to list one level per entry of `attributes`.
# NOTE(review): `baseline_list` is not passed to IMCE() below, which only
# receives the Country attribute and its reference level.
baseline_list <- c(
  "No formal",
  "female",
  "Germany",
  "Family",
  "Janitor",
  "No job training or prior experience",
  "Will look for work after arriving in the U.S.",
  "Never been to the U.S.",
  "Fluent"
)

# IMCEs for the Country attribute, with Germany as the reference level.
het_effects <- IMCE(
  data = data_use,
  model = cjmodel,
  attribs = "Country",
  ref_levels = "Germany",
  cores = 1
)

# Check effects: plot three country levels against party ID.
plot(
  het_effects,
  covar = "party_ID",
  plot_levels = c("Iraq", "France", "Mexico")
)

### FactorHet
# Read in the list of models saved by the main analysis and take its third
# element.
estimated_models <- readRDS("main_models/out_BIC.RDS")
est_mbo3 <- estimated_models[[3]]

# AME estimates for the selected model.
est_AME <- AME(est_mbo3)

# Individual-level effects (IMCEs) derived from the model and its AMEs.
facthet_ind <- HTE_by_individual(
  object = est_mbo3,
  AME = est_AME
)

# Reformat to explore effects for country
# Non-baseline country levels: every level of Country except the first.
# NOTE(review): assumes the first level of Country is the reference level
# ("Germany", the ref_levels value passed to IMCE()) -- confirm.
country_levels <- levels(repdata_nonhisp$Country)[-1]

# FactorHet individual-level estimates, restricted to the Country factor.
facthet_ind_country <- subset(facthet_ind$individual, factor == "Country")
facthet_ind_country2 <- facthet_ind_country[, c("group", "est", "level")]

# Convert the respondent identifier to numeric so it can be matched against
# CaseID. as.numeric() applied directly to a factor returns the internal
# integer codes rather than the labels, so a factor is converted through
# character first; other types are converted directly.
facthet_ind_country2$group <- if (is.factor(facthet_ind_country$group)) {
  as.numeric(as.character(facthet_ind_country$group))
} else {
  as.numeric(facthet_ind_country$group)
}

# cjbart IMCEs: keep the respondent identifier and one column per
# non-baseline country level.
het_effects_country <- het_effects$imce
het_effects_country <- het_effects_country[, c("CaseID", country_levels)]

# Reshape to long format: one row per respondent and country level, with the
# level name in `level` and the estimate in `est`.
het_effects_country2 <- pivot_longer(
  het_effects_country,
  cols = all_of(country_levels),
  names_to = "level",
  values_to = "est"
)

# Merge data sets for comparison
# Match FactorHet rows (keyed by group, level) to cjbart rows (keyed by
# CaseID, level). Both inputs carry an `est` column, so merge() suffixes them
# as est.x (FactorHet) and est.y (cjbart); these are renamed right after.
merge_imce <- merge(
  x = facthet_ind_country2,
  y = het_effects_country2,
  by.x = c("group", "level"),
  by.y = c("CaseID", "level")
)
names(merge_imce)[names(merge_imce) == "est.x"] <- "est.fh"
names(merge_imce)[names(merge_imce) == "est.y"] <- "est.cjbart"

# Moderator columns with the respondent identifier, de-duplicated.
data_mods <- repdata_nonhisp[, c(
  "CaseID", "scale_hisp_prej_flip", "party_ID",
  "ppEducat", "census_div", "ppEthm"
)]
data_mods2 <- distinct(data_mods)

# Attach the moderators to the merged estimates.
merge_imce2 <- merge(
  x = merge_imce,
  y = data_mods2,
  by.x = "group",
  by.y = "CaseID"
)

# Boxplot
# Stack the two estimate columns into long format (one row per method), then
# bin the prejudice score into four ntile() groups.
merge_imce2a <- merge_imce2 %>%
  pivot_longer(
    cols = c("est.cjbart", "est.fh"),
    names_to = "method",
    values_to = "est"
  ) %>%
  mutate(hisp_prej_rank = ntile(scale_hisp_prej_flip, 4))

# Rows for the Iraq level only.
merge_imce2b <- subset(merge_imce2a, level == "Iraq")

# Facet strip labels, keyed by the values of `method`.
method.labs <- c("est.cjbart" = "cjbart", "est.fh" = "FactorHet")

############# FIGURE A1 top #####################
# Boxplots of the Iraq estimates by prejudice-score quartile, one panel per
# method. The fill duplicates the x grouping, so the legend is hidden.
gg_iraq_prej <- ggplot(
  merge_imce2b,
  aes(
    x = as.factor(hisp_prej_rank),
    y = est,
    fill = as.factor(hisp_prej_rank)
  )
) +
  geom_boxplot() +
  facet_wrap(~method, labeller = labeller(method = method.labs)) +
  labs(
    title = "CAMCEs for Iraq relative to Germany",
    x = "Hispanic prejudice score quartile",
    y = "CAMCE"
  ) +
  theme(legend.position = "none")

ggsave(
  filename = "figures/compare_cjbart_prej1.pdf",
  plot = gg_iraq_prej,
  width = 6,
  height = 4
)

################ FIGURE A2 top ##################
# Horizontal boxplots of the estimates for every country level, coloured by
# prejudice-score quartile, one panel per method.
# The axis titles follow the aesthetics: x carries the estimate (CAMCE) and
# y carries the country level. The method is shown by the facet strips, not
# by an axis.
gg_compare_prej <- ggplot(
  merge_imce2a,
  aes(y = level, x = est, color = as.factor(hisp_prej_rank))
) +
  geom_boxplot() +
  xlab("CAMCE") +
  ylab("Country") +
  facet_wrap(~method, labeller = labeller(method = method.labs)) +
  labs(color = "Hispanic \nprejudice \nscore \nquartile")

ggsave(
  "figures/compare_cjbart_prej9.pdf",
  gg_compare_prej,
  width = 6,
  height = 4
)

# Numerical check: regress each method's estimates on the prejudice score and
# show the fit summaries (cjbart first, then FactorHet).
lm_cjbart_prej <- lm(
  formula = est.cjbart ~ scale_hisp_prej_flip,
  data = merge_imce2
)
lm_facthet_prej <- lm(
  formula = est.fh ~ scale_hisp_prej_flip,
  data = merge_imce2
)

summary(lm_cjbart_prej)
summary(lm_facthet_prej)

############# FIGURE A1 bottom #####################
# Add spaces after the slashes in the undecided/independent label so that
# str_wrap() below can break it across lines. The level is renamed by name
# rather than by position, so a different level order cannot relabel the
# wrong party.
levels(merge_imce2b$party_ID)[
  levels(merge_imce2b$party_ID) == "Undecided/Independent/Other"
] <- "Undecided/ Independent/ Other"

# Boxplots of the Iraq estimates by party ID, one panel per method. The fill
# duplicates the x grouping, so the legend is hidden.
gg_iraq_party <- ggplot(
  merge_imce2b,
  aes(x = party_ID, y = est, fill = party_ID)
) +
  geom_boxplot() +
  xlab("Party ID") +
  ylab("CAMCE") +
  facet_wrap(~method, labeller = labeller(method = method.labs)) +
  labs(title = "CAMCEs for Iraq relative to Germany") +
  theme(
    axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1),
    legend.position = "none"
  ) +
  # Wrap long party labels at roughly 10 characters.
  scale_x_discrete(labels = function(x) str_wrap(x, width = 10))

ggsave(
  "figures/compare_cjbart_party1.pdf",
  gg_iraq_party,
  width = 6,
  height = 4
)

################ FIGURE A2 bottom ##################

# Keep only respondents in three party ID categories.
merge_imce3 <- subset(
  merge_imce2,
  party_ID %in% c(
    "Strong Republican",
    "Strong Democrat",
    "Undecided/Independent/Other"
  )
)

# Stack the two estimate columns into long format (one row per method).
merge_imce3a <- pivot_longer(
  merge_imce3,
  cols = c("est.cjbart", "est.fh"),
  names_to = "method",
  values_to = "est"
)

# Boxplots of the estimates by method, coloured by party ID, one panel per
# country level. Legend and axis tick labels are replaced with display names.
gg_compare_party <- ggplot(
  merge_imce3a,
  aes(x = method, y = est, color = party_ID)
) +
  geom_boxplot() +
  facet_wrap(~level) +
  labs(x = "Method", y = "CAMCE", color = "Party ID") +
  scale_color_discrete(
    labels = c(
      "Strong Republican" = "Strong \nRepublican",
      "Strong Democrat" = "Strong \nDemocrat",
      "Undecided/Independent/Other" = "Undecided/ \nIndependent/ \nOther"
    )
  ) +
  scale_x_discrete(
    labels = c("est.cjbart" = "cjbart", "est.fh" = "FactorHet")
  )

ggsave(
  filename = "figures/compare_cjbart_party9.pdf",
  plot = gg_compare_party,
  width = 6,
  height = 4
)

# Numerical check that estimated effect varies with PID for FH and not for
# cjbart: regress each method's estimates on party ID within the
# three-category subset and show the fit summaries (cjbart first, then
# FactorHet).
lm_cjbart_party <- lm(
  formula = est.cjbart ~ party_ID,
  data = merge_imce3
)
lm_facthet_party <- lm(
  formula = est.fh ~ party_ID,
  data = merge_imce3
)

summary(lm_cjbart_party)
summary(lm_facthet_party)

